import numpy as np
import h5py as h5
import matplotlib.pyplot as plt

from matplotlib.lines import Line2D
from matplotlib import cm, gridspec
from matplotlib.colors import LinearSegmentedColormap, Colormap, to_rgb

# Load the shared figure style for every plot below. NOTE(review): this path is
# resolved relative to the current working directory, not to this script, so
# the script has to be run from a directory that contains paper.mplstyle.
plt.style.use("./paper.mplstyle")

_STYLE_DICT = {
    "bb": (0, (1, 3)),
    "ww": (0, (3, 3)),
    "tautau": (0, (5, 1)),
    "nuenue": (0, (3, 1, 1, 1)),
    "numunumu": (0, (3, 1, 1, 1, 1, 1)),
    "nutaunutau": (0, (3, 1, 1, 1, 1, 1, 1, 1)),
}

def initialize_args(argv=None):
    """Parse the command-line options of this plotting script.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse instead of ``sys.argv[1:]``. Leaving it as ``None``
        keeps the plain command-line behaviour; passing a list makes the
        parser usable from tests or other scripts.

    Returns
    -------
    argparse.Namespace
        With string attributes ``infile``, ``output_prefix`` and ``outdir``.
    """
    from argparse import ArgumentParser

    parser = ArgumentParser(description="Render the paper figures from an HDF5 data file.")
    parser.add_argument(
        "--infile",
        default="./paper_plots.h5",
        help="HDF5 file containing the plot data (default: %(default)s)",
    )
    parser.add_argument(
        "--output_prefix",
        default="",
        help="string prepended to every output file name (default: none)",
    )
    parser.add_argument(
        "--outdir",
        default="./",
        help="directory the figures are written to (default: %(default)s)",
    )
    return parser.parse_args(argv)

### Fig. 1 ###

def _mindist1(meanx):
    distances = np.abs(meanx - np.expand_dims(meanx, axis=1))
    x = distances.flatten()
    x = x[x>0]
    mindistance = x.min()
    return mindistance

def _rotate_positions(positions: np.ndarray, phi: float) -> np.ndarray:
    m = np.array([
        [np.cos(phi),  np.sin(phi), 0],
        [-np.sin(phi), np.cos(phi), 0],
        [0,           0,            1],
    ])

    new_positions = np.matmul(positions, m)
    return new_positions

def _find_good_phi(positions, nshots=10_000, mindistf=_mindist1, seed=2):
    """Randomly search for a z-axis rotation that spreads out the mean x values.

    ``nshots`` angles are drawn uniformly from [0, pi). For each angle:

    - ``positions`` is rotated with ``_rotate_positions``.
    - The x coordinate is averaged along axis 1, giving one value per entry
      of axis 0.
    - ``mindistf`` scores those values.

    The angle with the highest score is kept; ties keep the earlier draw.

    Parameters
    ----------
    positions : ndarray
        3-D array whose last axis holds (x, y, z). Axis 0 indexes the groups
        whose mean x values are compared (presumably detector strings; TODO
        confirm).
    nshots : int
        Number of random angles to try.
    mindistf : callable
        Maps the 1-D array of mean x values to a score; larger is better.
    seed : int
        Seed of the private random generator. The default, 2, reproduces the
        angles drawn after the former global ``np.random.seed(2)`` call.

    Returns
    -------
    float
        Best angle in radians, or NaN if no draw scored above 0 (e.g. when
        ``nshots`` is 0).
    """
    # A dedicated RandomState with the same seed yields exactly the same draws
    # as re-seeding NumPy's global generator. Unlike that call, it does not
    # reset global random state that other code in the process may rely on.
    rng = np.random.RandomState(seed)

    maxmindistance = 0
    best_phi = np.nan
    for _ in range(nshots):
        phi = rng.uniform(0, np.pi)
        meanx = np.mean(_rotate_positions(positions, phi)[:, :, 0], axis=1)
        mindistance = mindistf(meanx)
        # Strict ">" keeps the first angle that reaches a given score.
        if mindistance > maxmindistance:
            maxmindistance = mindistance
            best_phi = phi

    return best_phi

def _plot_detector(
    positions,
    is_dc,
    s=3,
    ax=None,
    phi=None,
    dc_color="dodgerblue",
    ic_color="crimson"
):
    """Scatter-plot module positions projected onto the rotated x-z plane.

    Parameters
    ----------
    positions : ndarray
        3-D array whose last axis holds (x, y, z) in metres.
    is_dc : ndarray
        One entry per position (same element count as ``positions[:, :, 0]``).
        Truthy entries are drawn in ``dc_color``, the rest in ``ic_color``.
    s : float
        Marker size.
    ax : matplotlib Axes, optional
        Target axes. A new 6 x 5.27 figure is created when omitted.
    phi : float, optional
        Rotation about the z axis, in radians. When omitted it is searched
        with ``_find_good_phi`` using 10 000 random draws.

    Returns
    -------
    matplotlib Axes
    """
    angle = phi if phi is not None else _find_good_phi(positions, nshots=10_000)
    rotated = _rotate_positions(positions, angle)

    if ax is None:
        ax = plt.subplots(figsize=(6, 5.27))[1]

    xs = rotated[:, :, 0].flatten()
    zs = rotated[:, :, 2].flatten()
    colors = [dc_color if flag else ic_color for flag in is_dc.flatten()]
    ax.scatter(xs, zs, s=s, edgecolor="none", color=colors)

    # Fixed view window, in metres.
    ax.set_xlim(-623, 610)
    ax.set_ylim(-2483, -1400)

    ax.set_xlabel(r"$x~\left[\mathrm{m}\right]$")
    ax.set_ylabel(r"$z~\left[\mathrm{m}\right]$")

    return ax

def _plot_effective_area(cents, hon, hit, ic_color="crimson", dc_color="dodgerblue", ax=None):
    """Draw stacked effective-area bands against neutrino energy.

    Three layers are drawn, all as mid-point steps centred on ``cents``:

    - ``hon`` ("DeepCore") fills from a 1e-8 floor up to ``hon``.
    - ``hit`` ("Main array") is stacked on top, from ``hon`` to ``hon + hit``.
    - The sum is outlined as a black "Total" step line.

    The axes are log-log, limited to 2-1e4 GeV in x and 1e-7-30 m^2 in y.

    Returns the axes drawn into (a new 6 x 5.27 figure when ``ax`` is None).
    """
    if ax is None:
        ax = plt.subplots(figsize=(6, 5.27))[1]

    total = hit + hon
    # Main-array band first, then the DeepCore band beneath it.
    for lower, upper, face, name in (
        (hon, total, ic_color, "Main array"),
        (1e-8, hon, dc_color, "DeepCore"),
    ):
        ax.fill_between(cents, lower, upper, step="mid", facecolor=face, label=name)
    ax.step(cents, total, where="mid", color="k", label="Total")

    ax.loglog()
    ax.legend()

    ax.set_xlim(2, 1e4)
    ax.set_ylim(1e-7, 3e1)

    ax.set_xlabel(r"$E_{\nu}~\left[\mathrm{GeV}\right]$")
    ax.set_ylabel(r"$A_{\mathrm{eff.}}~\left[\mathrm{m}^{2}\right]$")

    return ax

def make_detector_plot(h5filename, outdir, prefix, cs=None):
    """Draw Fig. 1: detector geometry (left) and effective areas (right).

    Reads group ``fig_1`` from ``h5filename`` and writes
    ``{outdir}/{prefix}detector_effa.pdf``.

    Parameters
    ----------
    h5filename : str
        HDF5 input file, opened read-only.
    outdir, prefix : str
        Output directory and file-name prefix.
    cs : sequence of two colours, optional
        ``(DeepCore colour, main-array colour)``. Defaults to
        ``["dodgerblue", "crimson"]``.
    """
    if cs is None:
        cs = ["dodgerblue", "crimson"]
    dc_color, ic_color = cs[0], cs[1]

    fig = plt.figure(figsize=(12.2, 5.27))
    gs = fig.add_gridspec(
        1, 2,
        width_ratios=(1, 1),
        left=0.1, right=0.9, bottom=0.1, top=0.9,
        wspace=0.25, hspace=0.05
    )
    ax0 = fig.add_subplot(gs[0])
    ax1 = fig.add_subplot(gs[1])

    # Open read-only explicitly: h5py < 3.0 defaulted to append mode, which can
    # create or modify the input file.
    with h5.File(h5filename, "r") as h5f:
        gp = h5f["fig_1"]

        _plot_detector(
            gp["om_positions"][:, :, :],
            gp["is_dc"][:, :],
            ic_color=ic_color,
            dc_color=dc_color,
            ax=ax0
        )
        _plot_effective_area(
            gp["energies"][:],
            gp["oscnext_effa"][:],
            gp["northerntracks_effa"][:],
            ic_color=ic_color,
            dc_color=dc_color,
            ax=ax1
        )

    fig.savefig(f"{outdir}/{prefix}detector_effa.pdf")
    # main() renders many figures in one process; release this one once saved.
    plt.close(fig)

### Fig. 2 ###

def _moving_average(a, n=6):
    ret = np.zeros(a.shape)
    for idx in range(len(ret)):
        if idx==0:
            ret[idx] = a[0]
            continue
        jdx = max(idx - n, 0)
        ret[idx] = np.mean(a[jdx:idx])
    return ret

def make_flux_plot(h5filename, outdir, prefix, cs=None):
    """Draw Fig. 2: muon-neutrino fluxes at Earth, with and without EWC.

    EWC presumably stands for electroweak corrections. Reads group ``fig_2``
    of ``h5filename``, laid out as
    ``fig_2/<channel>/<mass>/{ewc,no_ewc}/{energies,numuflux,numubarflux}``.
    For every channel and mass it draws
    ``E * _moving_average(numuflux + numubarflux) * mass``, passing the HDF5
    attributes of the ``ewc``/``no_ewc`` group to ``ax.plot`` as keyword
    arguments. Writes ``{outdir}/{prefix}fluxes_at_earth.pdf``.

    Parameters
    ----------
    cs : sequence of three colours, optional
        Colours of the 1e2, 1e3 and 1e4 GeV legend entries. Defaults to
        ``["red", "green", "blue"]``. NOTE(review): the curves take their
        styling from the HDF5 attributes, so these presumably must be kept in
        sync with the file by hand.
    """
    if cs is None:
        cs = ["red", "green", "blue"]
    alpha = 0.6

    # The legend is built by hand; the final invisible entry only pads the
    # three-column layout.
    handles = [
        Line2D([], [], label=r"$b\bar{b}$", ls="-", color="k"),
        Line2D([], [], label=r"$W^{+}W^{-}$", ls=":", color="k"),
        Line2D([], [], label=r"$\nu_{e}\bar{\nu}_{e}$", ls="--", color="k"),
        Line2D([], [], label=r"$10^{2}~\mathrm{GeV}$", color=cs[0]),
        Line2D([], [], label=r"$10^{3}~\mathrm{GeV}$", color=cs[1]),
        Line2D([], [], label=r"$10^{4}~\mathrm{GeV}$", color=cs[2]),
        Line2D([], [], label="W. EWC", color="k"),
        Line2D([], [], label="Wo. EWC", color=to_rgb("k") + (alpha,), lw=1),
        Line2D([], [], alpha=0.0),
    ]

    fig, ax = plt.subplots(figsize=(6, 5))

    # Read-only: h5py < 3.0 defaulted to append mode.
    with h5.File(h5filename, "r") as h5f:
        for channel, by_mass in h5f["fig_2"].items():
            for mass_key, gp in by_mass.items():
                mass = float(mass_key)
                for variant in ("ewc", "no_ewc"):
                    # Channel "12" only gets the EWC curve.
                    if variant == "no_ewc" and channel == "12":
                        continue
                    sub = gp[variant]
                    energies = sub["energies"][:]
                    flux = sub["numuflux"][:] + sub["numubarflux"][:]
                    ax.plot(
                        energies,
                        energies * _moving_average(flux) * mass,
                        **sub.attrs,
                    )

    ax.set_ylim(2e-4, 20)
    ax.set_xlim(1, 3e3)

    ax.set_xlabel(r"$E_{\nu}~\left[\mathrm{GeV}\right]$")
    ax.set_ylabel(r"$E_{\nu}\frac{\mathrm{d}N_{\nu_{\mu}}}{\mathrm{d}E_{\nu}\mathrm{d}A\mathrm{d}t}~\left[\mathrm{m}^{-2}\,\mathrm{s}^{-1}\right]$")

    ax.legend(handles=handles, ncol=3, fontsize=12, frameon=False, loc=2)
    ax.loglog()

    fig.savefig(f"{outdir}/{prefix}fluxes_at_earth.pdf")
    plt.close(fig)

### Fig. 3 ###

def _make_signal_plot(osc_next_dist, pt_src_dist, ax=None, vmin=None, vmax=None, cax=None):
    """Show expected signal counts vs. opening angle and reconstructed energy.

    The image is built as follows:

    - ``osc_next_dist`` is summed over its last axis and added to
      ``pt_src_dist``. After that sum both must broadcast; presumably
      (n_psi, n_energy), TODO confirm.
    - log10 of the sum is drawn with ``imshow`` over psi in [0, 180] deg and
      log10(E_reco / GeV) in [-3, 6].
    - The view is zoomed to psi <= 40 deg and 5 GeV <= E_reco <= 3000 GeV.

    The two text annotations are fixed strings describing the benchmark
    model; they are not derived from the inputs.

    Parameters
    ----------
    ax : matplotlib Axes, optional
        Target axes; a new figure is created when omitted.
    vmin, vmax : float, optional
        Colour limits in log10(counts).
    cax : matplotlib Axes, optional
        Axes for the colourbar; otherwise space is taken from ``ax``.

    Returns
    -------
    matplotlib Axes
    """
    if ax is None:
        _, ax = plt.subplots()

    # Transpose so energy runs along rows, then flip the rows so energy grows
    # upwards under imshow's default origin="upper".
    z = (osc_next_dist.sum(axis=-1) + pt_src_dist).T[::-1]

    im = ax.imshow(
        np.log10(z),
        extent=[0, 180, -3, 6],
        aspect="auto",
        cmap="plasma",
        vmax=vmax,
        vmin=vmin,
    )

    ax.set_ylim(np.log10(5), np.log10(3e3))
    ax.set_xlim(0, 40)

    ax.set_yticks([1, 2, 3])
    ax.set_yticklabels([r"$10^{1}$", r"$10^{2}$", r"$10^{3}$"])
    ax.set_ylabel(r"$E_{\mathrm{reco}}~\left[\mathrm{GeV}\right]$")

    ax.set_xlabel(r"$\psi~\left[^{\circ}\right]$")
    ax.text(16, np.log10(1900), r"$\chi\chi\rightarrow b\bar{b},~m_{\chi}=200\,\mathrm{GeV}$")
    ax.text(24, np.log10(1100), r"$\sigma_{\chi N}^{\mathrm{SD}}=10^{-40}\,\mathrm{cm}^2$")

    # Attach the colourbar through ax's own figure. pyplot's "current" figure
    # may be a different one when the caller supplies ``ax``.
    cbar = ax.figure.colorbar(im, ax=ax, cax=cax, label=r"$N_{\mathrm{evts}}$")

    cbar.set_ticks([-4, -3, -2, -1])
    cbar.set_ticklabels([r"$10^{-4}$", r"$10^{-3}$", r"$10^{-2}$", r"$10^{-1}$"])

    return ax

def _make_background_plot(osc_next_dist, pt_src_dist, ax=None, vmin=None, vmax=None, cax=None):
    """Show background (or observed) counts vs. opening angle and reco energy.

    The image is built as follows:

    - ``osc_next_dist`` is summed over its last axis and added to
      ``pt_src_dist``. After that sum both must broadcast; presumably
      (n_psi, n_energy), TODO confirm.
    - log10 of the sum is drawn with ``imshow`` over psi in [0, 180] deg and
      log10(E_reco / GeV) in [-3, 6].
    - The view is zoomed to psi <= 40 deg and 5 GeV <= E_reco <= 3000 GeV.

    Parameters
    ----------
    ax : matplotlib Axes, optional
        Target axes; a new figure is created when omitted.
    vmin, vmax : float, optional
        Colour limits in log10(counts).
    cax : matplotlib Axes, optional
        Axes for the colourbar; otherwise space is taken from ``ax``.

    Returns
    -------
    matplotlib Axes
    """
    if ax is None:
        _, ax = plt.subplots()

    # Transpose so energy runs along rows, then flip the rows so energy grows
    # upwards under imshow's default origin="upper".
    z = (osc_next_dist.sum(axis=-1) + pt_src_dist).T[::-1]

    im = ax.imshow(
        np.log10(z),
        extent=[0, 180, -3, 6],
        aspect="auto",
        cmap="viridis",
        vmax=vmax,
        vmin=vmin,
    )

    ax.set_ylim(np.log10(5), np.log10(3e3))
    ax.set_xlim(0, 40)

    ax.set_yticks([1, 2, 3])
    ax.set_yticklabels([r"$10^{1}$", r"$10^{2}$", r"$10^{3}$"])
    ax.set_ylabel(r"$E_{\mathrm{reco}}~\left[\mathrm{GeV}\right]$")

    ax.set_xlabel(r"$\psi~\left[^{\circ}\right]$")

    # Attach the colourbar through ax's own figure. pyplot's "current" figure
    # may be a different one when the caller supplies ``ax``.
    cbar = ax.figure.colorbar(im, ax=ax, cax=cax, label=r"$N_{\mathrm{evts}}$")

    cbar.set_ticks([-1, 0, 1, 2])
    cbar.set_ticklabels([r"$10^{-1}$", r"$10^{0}$", r"$10^{1}$", r"$10^{2}$"])

    return ax

def make_sig_and_bg_plot(h5filename, outdir, prefix):
    """Draw Fig. 3: expected signal (left) and background (right) distributions.

    Reads the OscNext and Northern-Tracks signal/background histograms from
    group ``fig_3`` of ``h5filename``. Each panel gets its own narrow
    colourbar axis.
    """
    # Read-only: h5py < 3.0 defaulted to append mode.
    with h5.File(h5filename, "r") as h5f:
        gp = h5f["fig_3"]
        osc_nom_sig = gp["oscnext_signal"][:, :, :]
        osc_nom_bg = gp["oscnext_background"][:, :, :]
        ps_nom_sig = gp["northerntracks_signal"][:, :]
        ps_nom_bg = gp["northerntracks_background"][:, :]

    fig = plt.figure(figsize=(12.2, 5.27))

    outer = fig.add_gridspec(
        1, 2,
        width_ratios=(1, 1),
        left=0.1, right=0.9, bottom=0.1, top=0.9,
        wspace=0.4, hspace=0.05
    )

    # Split each half of the figure into a main panel and a thin colourbar axis.
    panels = []
    for i in range(2):
        inner = gridspec.GridSpecFromSubplotSpec(
            1, 2, wspace=0.05, subplot_spec=outer[i], width_ratios=(6, 0.3)
        )
        panels.append((fig.add_subplot(inner[0]), fig.add_subplot(inner[1])))
    (ax0, cax0), (ax1, cax1) = panels

    _make_signal_plot(osc_nom_sig, ps_nom_sig, ax=ax0, vmin=-4.67, cax=cax0)
    _make_background_plot(osc_nom_bg, ps_nom_bg, ax=ax1, cax=cax1, vmin=-1.9, vmax=2.4)

    # NOTE(review): "distirbution" is misspelled, but the file name is kept so
    # anything that references the existing output keeps working.
    fig.savefig(f"{outdir}/{prefix}signal_background_distirbution.pdf")
    plt.close(fig)

### Fig. 4 ###

def make_experimental_data_plot(h5filename, outdir, prefix):
    """Draw Fig. 4: unblinded data, on the background panel's colour scale.

    Reads group ``fig_4`` from ``h5filename`` and writes
    ``{outdir}/{prefix}experimental_data.pdf``.
    """
    # Read-only: h5py < 3.0 defaulted to append mode.
    with h5.File(h5filename, "r") as h5f:
        gp = h5f["fig_4"]
        oscnext_data = gp["oscnext_unblinded_data"][:, :, :]
        northerntracks_data = gp["northerntracks_unblinded_data"][:, :]

    # Same vmin/vmax as the background panel of Fig. 3, so the two are comparable.
    ax = _make_background_plot(oscnext_data, northerntracks_data, vmin=-1.9, vmax=2.4)
    ax.figure.savefig(f"{outdir}/{prefix}experimental_data.pdf")
    plt.close(ax.figure)

### Fig. 5 ###

def _make_sd_limits_helper(
    limits,
    ax=None,
    lw_ic=2,
    lw_other=2,
    cmap1=None,
    cmap2=None,
    fontsize=10
):
    """Draw spin-dependent DM-proton cross-section limits with text labels.

    Parameters
    ----------
    limits : list
        Mixed list of limit curves.

        - A ``dict`` entry maps ``_STYLE_DICT`` channel keys to ``(N, 2)``
          arrays of (mass, limit) rows. It is drawn once per channel, in that
          channel's line style.
        - Any other entry is a single ``(N, 2)`` array, drawn solid.
        - The first entry is drawn on top (zorder 10) with width ``lw_ic``.

        The text labels assume the dict entries are IceCube, Super-K and
        ANTARES, and the array entries LZ, XENONnT and PICO, each in that
        order.
    ax : matplotlib Axes, optional
        Target axes; a new 10 x 6 figure is created when omitted.
    lw_ic, lw_other : float
        Line width of the first entry and of all others (also legend samples).
    cmap1, cmap2 : Colormap or sequence of colours, optional
        Colours of the dict entries and of the array entries respectively. A
        Colormap is sampled evenly over [0, 1] across the entries of its kind;
        a sequence is indexed in order. Default to the "copper" and "cool"
        colormaps.
    fontsize : float
        Font size of the experiment labels.

    Returns
    -------
    matplotlib Axes
    """
    # Resolve default colormaps at call time. The former module-level
    # ``cm.get_cmap`` defaults crash at import on Matplotlib >= 3.9, where
    # that function was removed.
    if cmap1 is None:
        cmap1 = plt.get_cmap("copper")
    if cmap2 is None:
        cmap2 = plt.get_cmap("cool")
    if ax is None:
        _, ax = plt.subplots(figsize=(10, 6))

    # Legend: one black sample line per annihilation channel, in the channel's
    # dash pattern.
    channel_labels = [
        ("bb", r"$b\bar{b}$"),
        ("nuenue", r"$\nu_{e}\bar{\nu}_{e}$"),
        ("ww", r"$W^{+}W^{-}$"),
        ("numunumu", r"$\nu_{\mu}\bar{\nu}_{\mu}$"),
        ("tautau", r"$\tau^{+}\tau^{-}$"),
        ("nutaunutau", r"$\nu_{\tau}\bar{\nu}_{\tau}$"),
    ]
    handles = [
        Line2D([], [], color="k", lw=lw_other, label=label, ls=_STYLE_DICT[key])
        for key, label in channel_labels
    ]

    def pick_color(cmap, i, n):
        # Sample a Colormap evenly across the n entries of one kind; index a
        # plain sequence of colours directly.
        if isinstance(cmap, Colormap):
            return cmap(i / max(n - 1, 1))
        return cmap[i]

    n_direct = sum(not isinstance(limit, dict) for limit in limits)
    n_indirect = len(limits) - n_direct
    id_cs = []  # colours of the dict (per-channel) entries, in order
    d_cs = []   # colours of the single-curve entries, in order

    for idx, limit in enumerate(limits):
        # The first entry is drawn on top and with its own line width.
        kwargs = {"lw": lw_ic, "zorder": 10} if idx == 0 else {"lw": lw_other}
        if isinstance(limit, dict):
            color = pick_color(cmap1, len(id_cs), n_indirect)
            id_cs.append(color)
            for channel, curve in limit.items():
                ax.plot(curve[:, 0], curve[:, 1], ls=_STYLE_DICT[channel], c=color, **kwargs)
        else:
            color = pick_color(cmap2, len(d_cs), n_direct)
            d_cs.append(color)
            ax.plot(limit[:, 0], limit[:, 1], c=color, **kwargs)

    ax.set_xlim(2, 1e4)
    ax.set_ylim(5e-44, 5e-36)
    # Same x label as the spin-independent figure.
    ax.set_xlabel(r"$m_{\chi}~\left[\mathrm{GeV}\right]$")
    ax.set_ylabel(r"$\sigma_{p\chi}^{\mathrm{SD}}~\left[\mathrm{cm^{2}}\right]$")

    ax.loglog()

    # (x, y, text, colour, rotation): positions are hand-tuned for the axis
    # limits above.
    annotations = [
        (10, 3e-41, "LZ (2025)", d_cs[0], -70),
        (6.5, 1.5e-39, "XENONnT (2023)", d_cs[1], -75),
        (3, 2e-39, "PICO (2019)", d_cs[2], -81),
        (55, 1.2e-38, "ANTARES (2016)", id_cs[2], -37),
        (20, 1.8e-39, "SUPER-K (2015)", id_cs[1], 15),
        (800, 1e-41, "IC (This work)", id_cs[0], 45),
    ]
    for x, y, text, color, rotation in annotations:
        ax.text(x, y, text, color=color, fontsize=fontsize, rotation=rotation)

    # Legend on ``ax`` itself, not on whatever axes pyplot considers current.
    ax.legend(
        handles=handles,
        ncol=3,
        loc="upper center",
        framealpha=0.0,
    )

    return ax


def make_sd_limits_plot(h5filename, outdir, prefix):
    """Draw Fig. 5: spin-dependent DM-proton cross-section limits.

    Reads group ``fig_5`` from ``h5filename``:

    - per-channel curves under ``icecube``, ``super_kamiokande`` and ``antares``;
    - single curves under ``lz_2024``, ``xenon`` and ``pico``.

    Every curve group holds ``masses`` and ``limits`` datasets. Writes
    ``{outdir}/{prefix}sd_limits.pdf``.
    """
    def _curve(group):
        # Stack a group's masses and limits into an (N, 2) array of rows.
        return np.array([group["masses"][:], group["limits"][:]]).T

    # Read-only: h5py < 3.0 defaulted to append mode.
    with h5.File(h5filename, "r") as h5f:
        gp = h5f["fig_5"]
        antares_lims = {k: _curve(g) for k, g in gp["antares"].items()}
        sk_lims = {k: _curve(g) for k, g in gp["super_kamiokande"].items()}
        ic_lims = {k: _curve(g) for k, g in gp["icecube"].items()}
        pico_lims = _curve(gp["pico"])
        lz_lims = _curve(gp["lz_2024"])
        xenon_lims = _curve(gp["xenon"])

    # Order matters: entry 0 is drawn as "this work". The helper assigns
    # colours and text labels to the remaining entries in this order.
    limits = [ic_lims, lz_lims, xenon_lims, pico_lims, sk_lims, antares_lims]

    fig, ax = plt.subplots(figsize=(6, 7))
    _make_sd_limits_helper(
        limits,
        ax=ax,
        cmap1=["crimson", "forestgreen", "dodgerblue"],
        # pyplot.get_cmap replaces cm.get_cmap, which Matplotlib 3.9 removed.
        cmap2=plt.get_cmap("copper"),
    )
    fig.savefig(f"{outdir}/{prefix}sd_limits.pdf")
    plt.close(fig)

### Suppl. Fig. 1 ###

def make_solar_atm_plot(h5filename, outdir, prefix):
    """Draw Suppl. Fig. 1: comparison of solar atmospheric muon-neutrino fluxes.

    Reads group ``sup_fig_1`` from ``h5filename`` and writes
    ``{outdir}/{prefix}sanu_comparison.pdf``. Every sub-group provides
    ``energies`` and ``numuflux`` datasets. The group attribute ``colors`` is a
    whitespace-separated string of at least three colour names. All curves are
    drawn as E^2 * flux.

    Raises
    ------
    ValueError
        If ``sup_fig_1`` holds no models for the min/max envelope.
    """
    # Sub-groups drawn individually, and therefore left out of the min/max
    # envelope shown as the "Argüelles, et al. (2017) uncertainty" band.
    envelope_excluded = {
        "beacom_lower",
        "beacom_upper",
        "edsjo",
        "SIBYLL2.3_ppMRS_CombinedGHAndHG_H4a_le_interp",
    }

    # Read-only: h5py < 3.0 defaulted to append mode.
    with h5.File(h5filename, "r") as h5f:
        gp = h5f["sup_fig_1"]
        le_interp = gp["SIBYLL2.3_ppMRS_CombinedGHAndHG_H4a_le_interp"]
        nominal = gp["SIBYLL2.3_ppMRS_CombinedGHAndHG_H4a"]
        ng_upper = gp["beacom_upper"]
        ng_lower = gp["beacom_lower"]
        edsjo = gp["edsjo"]

        colors = gp.attrs["colors"].split()

        fig, ax = plt.subplots()

        # Band between the upper and lower curves, sampled on the upper
        # curve's energy grid (assumes both share that grid).
        ax.fill_between(
            ng_upper["energies"][:],
            ng_upper["energies"][:]**2 * ng_upper["numuflux"][:],
            ng_lower["energies"][:]**2 * ng_lower["numuflux"][:],
            facecolor=to_rgb(colors[2]) + (0.2,),
            edgecolor=colors[2],
            linewidth=1,
            label="Ng, et al. (2017) uncertainty"
        )

        # Pointwise max/min over all remaining models. The last model's
        # energies are used as the shared grid.
        es = upper = lower = None
        for name, model in gp.items():
            if name in envelope_excluded:
                continue
            es = model["energies"][:]
            flux = model["numuflux"][:]
            if upper is None:
                upper = np.full(es.shape, -np.inf)
                lower = np.full(es.shape, np.inf)
            upper = np.where(flux > upper, flux, upper)
            lower = np.where(flux < lower, flux, lower)
        if upper is None:
            raise ValueError("sup_fig_1 contains no models for the uncertainty envelope")

        ax.fill_between(
            es,
            es**2 * upper,
            es**2 * lower,
            facecolor=to_rgb(colors[0]) + (0.2,),
            edgecolor=colors[0],
            linewidth=1,
            label="Argüelles, et al. (2017) uncertainty"
        )

        le_e = le_interp["energies"][:]
        le_flux = le_interp["numuflux"][:]
        nom_e = nominal["energies"][:]
        eds_e = edsjo["energies"][:]

        ax.plot(le_e, le_e**2 * le_flux, c=colors[0], ls="dotted")
        ax.plot(nom_e, nom_e**2 * nominal["numuflux"][:], c=colors[0], label="Nominal model", zorder=100)
        ax.plot(eds_e, eds_e**2 * edsjo["numuflux"][:], label="Edsjö et al. (2017)")
        # NOTE(review): "Limit" is the _le_interp curve scaled by a fixed 2.48;
        # the origin of that factor is not recorded here. Confirm it.
        ax.plot(le_e, 2.48 * le_e**2 * le_flux, c=colors[1], label="Limit")

    ax.loglog()
    ax.set_ylabel(r"$E_{\nu}^2 \Phi_{\nu_{\mu}+\bar{\nu}_{\mu}}~\left[\mathrm{GeV}\,\mathrm{cm}^{-2}\,\mathrm{s}^{-1}\right]$")
    ax.set_ylim(7e-13, 2e-6)
    ax.set_yticks(np.logspace(-12, -6, 7))
    ax.set_yticklabels([rf"$10^{{{e}}}$" for e in range(-12, -5)])

    ax.set_xlabel(r"$E_{\nu}~\left[\mathrm{GeV}\right]$")
    ax.set_xlim(3, 5e4)

    ax.legend(fontsize=12, loc=3, framealpha=0.0)

    fig.savefig(f"{outdir}/{prefix}sanu_comparison.pdf")
    plt.close(fig)

### Suppl. Fig. 2 ###

def _si_limits_plot_helper(
    limits,
    ax=None,
    lw_ic=2,
    lw_other=2,
    cmap1=None,
    cmap2=None,
    fontsize=10
):
    """Draw spin-independent DM-proton cross-section limits with text labels.

    Parameters
    ----------
    limits : list
        Mixed list of limit curves.

        - A ``dict`` entry maps ``_STYLE_DICT`` channel keys to ``(N, 2)``
          arrays of (mass, limit) rows. It is drawn once per channel, in that
          channel's line style.
        - Any other entry is a single ``(N, 2)`` array, drawn solid.
        - The first entry is drawn on top (zorder 10) with width ``lw_ic``.

        The text labels assume the dict entries are IceCube, Super-K and
        ANTARES, and the array entries LZ, XENONnT and PICO, each in that
        order.
    ax : matplotlib Axes, optional
        Target axes; a new 10 x 6 figure is created when omitted.
    lw_ic, lw_other : float
        Line width of the first entry and of all others (also legend samples).
    cmap1, cmap2 : Colormap or sequence of colours, optional
        Colours of the dict entries and of the array entries respectively. A
        Colormap is sampled evenly over [0, 1] across the entries of its kind;
        a sequence is indexed in order. Default to the "copper" and "cool"
        colormaps.
    fontsize : float
        Font size of the experiment labels.

    Returns
    -------
    matplotlib Axes
    """
    # Resolve default colormaps at call time. The former module-level
    # ``cm.get_cmap`` defaults crash at import on Matplotlib >= 3.9, where
    # that function was removed.
    if cmap1 is None:
        cmap1 = plt.get_cmap("copper")
    if cmap2 is None:
        cmap2 = plt.get_cmap("cool")
    if ax is None:
        _, ax = plt.subplots(figsize=(10, 6))

    # Legend: one black sample line per annihilation channel, in the channel's
    # dash pattern.
    channel_labels = [
        ("bb", r"$b\bar{b}$"),
        ("nuenue", r"$\nu_{e}\bar{\nu}_{e}$"),
        ("ww", r"$W^{+}W^{-}$"),
        ("numunumu", r"$\nu_{\mu}\bar{\nu}_{\mu}$"),
        ("tautau", r"$\tau^{+}\tau^{-}$"),
        ("nutaunutau", r"$\nu_{\tau}\bar{\nu}_{\tau}$"),
    ]
    handles = [
        Line2D([], [], color="k", lw=lw_other, label=label, ls=_STYLE_DICT[key])
        for key, label in channel_labels
    ]

    def pick_color(cmap, i, n):
        # Sample a Colormap evenly across the n entries of one kind; index a
        # plain sequence of colours directly.
        if isinstance(cmap, Colormap):
            return cmap(i / max(n - 1, 1))
        return cmap[i]

    n_direct = sum(not isinstance(limit, dict) for limit in limits)
    n_indirect = len(limits) - n_direct
    id_cs = []  # colours of the dict (per-channel) entries, in order
    d_cs = []   # colours of the single-curve entries, in order

    for idx, limit in enumerate(limits):
        # The first entry is drawn on top and with its own line width.
        kwargs = {"lw": lw_ic, "zorder": 10} if idx == 0 else {"lw": lw_other}
        if isinstance(limit, dict):
            color = pick_color(cmap1, len(id_cs), n_indirect)
            id_cs.append(color)
            for channel, curve in limit.items():
                ax.plot(curve[:, 0], curve[:, 1], ls=_STYLE_DICT[channel], c=color, **kwargs)
        else:
            color = pick_color(cmap2, len(d_cs), n_direct)
            d_cs.append(color)
            ax.plot(limit[:, 0], limit[:, 1], c=color, **kwargs)

    ax.set_xlim(4, 1e4)
    ax.set_ylim(1.5e-49, 2e-41)

    ax.set_xlabel(r"$m_{\chi}~\left[\mathrm{GeV}\right]$")
    ax.set_ylabel(r"$\sigma^{\mathrm{SI}}_{p\chi}~\left[\mathrm{cm^{2}}\right]$")

    ax.loglog()

    # (x, y, text, colour, rotation): positions are hand-tuned for the axis
    # limits above.
    annotations = [
        (100, 4.0e-48, "LZ (2025)", d_cs[0], 28),
        (50, 4.0e-47, "XENONnT (2023)", d_cs[1], 25),
        (4.5, 2e-43, "PICO (2019)", d_cs[2], -62),
        (800, 4.0e-43, "ANTARES (2016)", id_cs[2], 35),
        (4.5, 1.4e-42, "SUPER-K (2015)", id_cs[1], -33),
        (800, 1e-44, "IC (This work)", id_cs[0], 45),
    ]
    for x, y, text, color, rotation in annotations:
        ax.text(x, y, text, color=color, fontsize=fontsize, rotation=rotation)

    # Legend on ``ax`` itself, not on whatever axes pyplot considers current.
    ax.legend(
        handles=handles,
        ncol=3,
        loc="lower center",
        framealpha=0.0,
    )

    return ax

def make_si_limits_plot(h5filename, outdir, prefix):
    """Draw Suppl. Fig. 2: spin-independent DM-proton cross-section limits.

    Reads group ``sup_fig_2`` from ``h5filename``:

    - per-channel curves under ``icecube``, ``super_kamiokande`` and ``antares``;
    - single curves under ``lz_2024``, ``xenon`` and ``pico``.

    Every curve group holds ``masses`` and ``limits`` datasets. Writes
    ``{outdir}/{prefix}si_limits.pdf``.
    """
    def _curve(group):
        # Stack a group's masses and limits into an (N, 2) array of rows.
        return np.array([group["masses"][:], group["limits"][:]]).T

    # Read-only: h5py < 3.0 defaulted to append mode.
    with h5.File(h5filename, "r") as h5f:
        gp = h5f["sup_fig_2"]
        antares_lims = {k: _curve(g) for k, g in gp["antares"].items()}
        sk_lims = {k: _curve(g) for k, g in gp["super_kamiokande"].items()}
        ic_lims = {k: _curve(g) for k, g in gp["icecube"].items()}
        pico_lims = _curve(gp["pico"])
        lz_lims = _curve(gp["lz_2024"])
        xenon_lims = _curve(gp["xenon"])

    # Order matters: entry 0 is drawn as "this work". The helper assigns
    # colours and text labels to the remaining entries in this order.
    limits = [ic_lims, lz_lims, xenon_lims, pico_lims, sk_lims, antares_lims]

    fig, ax = plt.subplots(figsize=(6, 7))
    _si_limits_plot_helper(
        limits,
        ax=ax,
        cmap1=["crimson", "forestgreen", "dodgerblue"],
        # pyplot.get_cmap replaces cm.get_cmap, which Matplotlib 3.9 removed.
        cmap2=plt.get_cmap("copper"),
    )
    fig.savefig(f"{outdir}/{prefix}si_limits.pdf")
    plt.close(fig)

### Suppl. Fig. 3 ###

def make_electron_limits_plot(h5filename, outdir, prefix):
    """Draw Suppl. Fig. 3: DM-electron scattering cross-section limits.

    Reads group ``sup_fig_3`` from ``h5filename``: per-channel IceCube curves
    under ``icecube/<channel>`` and a XENON curve under ``xenon``, each with
    ``masses`` and ``limits`` datasets. Writes
    ``{outdir}/{prefix}electron_limits.pdf``.
    """
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap
    # returns the same colormap on every version.
    xenon_color = plt.get_cmap("copper")(0.5)

    # Legend: one black sample line per annihilation channel, in the channel's
    # dash pattern.
    channel_labels = [
        ("bb", r"$b\bar{b}$"),
        ("nuenue", r"$\nu_{e}\bar{\nu}_{e}$"),
        ("ww", r"$W^{+}W^{-}$"),
        ("numunumu", r"$\nu_{\mu}\bar{\nu}_{\mu}$"),
        ("tautau", r"$\tau^{+}\tau^{-}$"),
        ("nutaunutau", r"$\nu_{\tau}\bar{\nu}_{\tau}$"),
    ]
    handles = [
        Line2D([], [], color="k", lw=2, label=label, ls=_STYLE_DICT[key])
        for key, label in channel_labels
    ]

    fig, ax = plt.subplots(figsize=(6, 7))

    # Read-only: h5py < 3.0 defaulted to append mode.
    with h5.File(h5filename, "r") as h5f:
        gp = h5f["sup_fig_3"]
        xenon_masses = gp["xenon/masses"][:]
        xenon_limits = gp["xenon/limits"][:]

        # Faint guide line from 10 to 300 GeV, proportional to mass and
        # anchored at the last (mass, limit) point of the XENON curve.
        ms = np.logspace(1, np.log10(300))
        ax.plot(ms, ms / xenon_masses[-1] * xenon_limits[-1], color=xenon_color, alpha=0.4)

        for channel, curve in gp["icecube"].items():
            ax.plot(curve["masses"][:], curve["limits"][:], color="crimson", ls=_STYLE_DICT[channel])
        ax.plot(xenon_masses, xenon_limits, color=xenon_color)

    ax.loglog()

    ax.set_xlabel(r"$m_{\chi}~\left[\mathrm{GeV}\right]$")
    ax.set_xlim(0.5, 1e4)

    ax.set_ylabel(r"$\sigma_{e\chi}~\left[\mathrm{cm^{2}}\right]$")
    ax.set_ylim(4e-42, 6e-37)

    ax.text(1, 3.0e-40, "XENON1T (2021)", color=xenon_color, fontsize=10, rotation=42)
    ax.text(800, 1.2e-39, "IC (This work)", color="crimson", fontsize=10, rotation=62)

    # Legend on ``ax`` itself, not on whatever axes pyplot considers current.
    ax.legend(
        handles=handles,
        ncol=3,
        loc="upper center",
        framealpha=0.0,
    )

    fig.savefig(f"{outdir}/{prefix}electron_limits.pdf")
    plt.close(fig)

def main(args=None):
    """Render every paper figure into ``args.outdir``.

    Parameters
    ----------
    args : argparse.Namespace, optional
        Must provide ``infile``, ``output_prefix`` and ``outdir``. Parsed from
        the command line with ``initialize_args`` when omitted.
    """
    import os

    if args is None:
        args = initialize_args()

    prefix = args.output_prefix
    # Separate a user-supplied prefix from the base file name with "_".
    if prefix and not prefix.endswith("_"):
        prefix += "_"

    # savefig does not create directories, so make sure the target exists.
    if args.outdir:
        os.makedirs(args.outdir, exist_ok=True)

    for make_plot in (
        make_experimental_data_plot,
        make_electron_limits_plot,
        make_si_limits_plot,
        make_detector_plot,
        make_flux_plot,
        make_sig_and_bg_plot,
        make_sd_limits_plot,
        make_solar_atm_plot,
    ):
        make_plot(args.infile, args.outdir, prefix)


if __name__ == "__main__":
    main()
